-
Notifications
You must be signed in to change notification settings - Fork 195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds AdeMAMix Optimizer to contrib
#1104
base: main
Are you sure you want to change the base?
Conversation
optax/contrib/_ademamix.py
Outdated
class ScaleByAdemamixState(NamedTuple): | ||
"""State for the Ademamix algorithm.""" | ||
|
||
count: chex.Array |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A comment here about the shape of count is useful (since it's very specific) and commonplace in the code: https://github.com/google-deepmind/optax/blob/main/optax/_src/transform.py#L88
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh nice. I see it is
count: chex.Array # shape=(), dtype=jnp.int32.
everywhere.
optax/contrib/_ademamix.py
Outdated
alpha_scheduler: Optional[base.ScalarOrSchedule] = None, | ||
eps: float = 1e-8, | ||
) -> base.GradientTransformation: | ||
"""Rescale updates according to the Ademamix algorithm. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A brief description of AdEMAMix and how it actually operates may be useful. For example, a comparison to Adam (on which it is based).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alternatively, pseudo-code resembling https://github.com/google-deepmind/optax/blob/main/optax/_src/alias.py#L586-L594 may be helpful.
optax/contrib/_ademamix.py
Outdated
|
||
Args: | ||
b1: Exponential decay rate to track the first moment of past gradients for | ||
the first Exponential Moving Average (EMA) - same as AdamW |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is essentially duplicate information relative to "Exponential decay rate". I'd suggest mentioning EMA in a general description above, and then making the arg information resemble that of AdamW.
optax/contrib/_ademamix.py
Outdated
b3: float = 0.9999, | ||
alpha: float = 5.0, | ||
b3_scheduler: Optional[base.ScalarOrSchedule] = None, | ||
alpha_scheduler: Optional[base.ScalarOrSchedule] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not make alpha (and b3) a base.ScalarOrSchedule instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed the patter with the majority of the other classes that did not use base.ScalarOrSchedule
as the type and tried to replicate that.
Of course, the authors in the paper specifically recommend scheduling b3
and alpha
because of early training instabilities, so this is a bit of a different use case.
I will make this change.
optax/contrib/_ademamix.py
Outdated
"""State for the Ademamix algorithm.""" | ||
|
||
count: chex.Array | ||
count_m2: chex.Array |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure why you have multiple counts. Would just the first suffice?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the first count
replicates the same entitiy in adam
. The second is for keeping track of the second momentum, so the user might want to do different things with each.
Does that make sense?
I will make an inline comment to this fact
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm still not convinced that there's utility in having separate counts - they're incremented identically in the actual code. I would lean towards having a single count, but I don't think this is blocking.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Super cool work - reproducing results is definitely under-appreciated. This is really nice!
optax/contrib/_ademamix.py
Outdated
m1 = otu.tree_update_moment( | ||
updates, state.m1, b1, 1 | ||
) # m1 = b1 * m1 + (1-b1) * updates | ||
m2 = otu.tree_update_moment(updates, state.m2, c_b3, 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know this isn't done elsewhere in optax, but generally I prefer opaque constants (e.g. 1 in this case) to be passed in via kwargs. Consider this optional though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so:
m2 = otu.tree_update_moment(updates, state.m2, c_b3, order=1)
is what you are describing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Exactly.
Hey @zcharles8 in rereading the paper I subsequently discovered the authors have a full I will email them and see if they have any qualms about me using aspects of their code as part of this PR. Do you know if there would be licensing problems with this approach? Please advise. |
Great question, probably there would from reading Apple's license... Let me send a mail to one of the authors I know. |
Thanks @vroulet I also tried sending an email to all 3 authors (but had to google their contact info so those might be outdated) Basically I just asked if they are keen to get a version into But if they don't want a version in |
Emailed with Matteo Pagliardini and he wrote:
So I went ahead and completed the PR. One open question: AdEMAMix uses a pretty bespoke scheduler for Both are now used in the |
optax/contrib/_ademamix.py
Outdated
import chex | ||
import jax.numpy as jnp | ||
import jax.tree_util as jtu | ||
from optax._src import base | ||
from optax._src import combine | ||
from optax._src import numerics | ||
from optax._src import transform | ||
|
||
import optax.tree_utils as otu | ||
from typing import NamedTuple, Tuple |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the convention in this library is that imports from python generally go at the top, so
from typing import...
import chex
...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah! Good point. (I was going to make a meta point about this repo not having any pre-commit
hooks but...).
I think that adding the schedulers to the library ( FWIW I think that this PR is really good (ie. things like typing information) and would prefer to have it pushed through (nice work @mathDR !) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great. Thank you @mathDR ! And thank you @zcharles8 for the review!
I've left a few minor comments. Mostly, it would be good to clean up the docstring of scale_by_ademamix now that you did the docstring of ademamix. And you;ll be able to resolve any remaining comments that way.
Thank you again to both of you!
optax/contrib/_ademamix.py
Outdated
eps: float = 1e-8, | ||
eps_root: float = 0.0, | ||
) -> base.GradientTransformation: | ||
"""Rescale updates according to the Ademamix algorithm. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you update this docstring too? The ademamix
docstring looks great. You may simply refer here to the docstring of ademamix
for e.g. the pseudocode and copy-paste e.g. the descriptions of the arguments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So to clarify: just a link to the ademamix
docstring? That is, don't follow the convention of:
Description
References:
Args:
Returns:
A `GradientTransformation` object.
?
Note: the above convention is prevalent throughout optax/_src/transform.py
for similar scale_by_X
functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use the following docstring
"""Scale updates according to the Ademamix algorithm.
See :func:`optax.contrib.ademamix.` for a full description of the algorithm.
References:
Pagliardini et al, `The AdEMAMix Optimizer: Better, Faster, Older
<https://arxiv.org/abs/2409.03137>`_, 2024
Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
b1: Exponential decay rate to track the fast EMA.
b2: Exponential decay rate to track the second moment of past gradients.
b3: Exponential decay rate to track the slow EMA.
alpha: Mixing coefficient in the linear combination fo the fast and
slow EMAs.
eps: A small constant applied to denominator outside of the square root
(as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root: A small constant applied to denominator inside the square root (as
in RMSProp), to avoid dividing by zero when rescaling. This is needed for
instance when computing (meta-)gradients through Adam.
mu_dtype: Optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.
Returns:
The corresponding `GradientTransformation`.
"""
optax/contrib/_ademamix.py
Outdated
import optax.tree_utils as otu | ||
|
||
class ScaleByAdemamixState(NamedTuple): | ||
"""State for the Ademamix algorithm.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a description of the attributes in the docstring?
See e.g.
Line 1448 in b8c2e13
class ScaleByLBFGSState(NamedTuple): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I took a stab. LMK if there are better descriptions I could do. I (theoretically) could add the math update if you believe that would be warranted. I also want to get the description right. nu
is still the estimate of the second moment, but now the first moment is a combination of m1
and m2
so I just stated their respective EMA types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great to me.
I think I would say "fast EMA of first moment" rather than first moment personally (same for "slow EMA" and this would apply in the other docstrings). But if you have stronger preferences I won't argue against.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, final comments and then we'll be good to go I think
optax/contrib/_ademamix.py
Outdated
S_t &\leftarrow (m1_t, m2_t, v_t). | ||
\end{align*} | ||
|
||
Limitations: AdEMAMix consists in leveraging very old gradients. Therefore, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use the following. (the current formatting won't be well handled by docs).
.. note::
AdEMAMix consists in leveraging very old gradients. Therefore,
the method is best suited to settings where the number of iterations is
important. The paper reports on this effect in Appendix C.1.5, showing how
smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations
scenarios. Moreover, retaining gradient information over many thousands of
steps can pose a problem in domains requiring fast adaptation to a sudden
distribution shift, or general cases in which the distribution is
non-stationary.
Also, you may build the docs using (from the folder docs
)
make html-noplot
that way you can check whether the docs are well formatted.
optax/contrib/_ademamix.py
Outdated
eps: float = 1e-8, | ||
eps_root: float = 0.0, | ||
) -> base.GradientTransformation: | ||
"""Rescale updates according to the Ademamix algorithm. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use the following docstring
"""Scale updates according to the Ademamix algorithm.
See :func:`optax.contrib.ademamix.` for a full description of the algorithm.
References:
Pagliardini et al, `The AdEMAMix Optimizer: Better, Faster, Older
<https://arxiv.org/abs/2409.03137>`_, 2024
Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
b1: Exponential decay rate to track the fast EMA.
b2: Exponential decay rate to track the second moment of past gradients.
b3: Exponential decay rate to track the slow EMA.
alpha: Mixing coefficient in the linear combination fo the fast and
slow EMAs.
eps: A small constant applied to denominator outside of the square root
(as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root: A small constant applied to denominator inside the square root (as
in RMSProp), to avoid dividing by zero when rescaling. This is needed for
instance when computing (meta-)gradients through Adam.
mu_dtype: Optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.
Returns:
The corresponding `GradientTransformation`.
"""
optax/contrib/_ademamix.py
Outdated
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. | ||
b1: Exponential decay rate to track the fast EMA. | ||
b2: Exponential decay rate to track the second moment of past gradients. | ||
b3: Exponenital decay rate to track the slow EMA. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo: Exponential
optax/contrib/_ademamix.py
Outdated
import optax.tree_utils as otu | ||
|
||
class ScaleByAdemamixState(NamedTuple): | ||
"""State for the Ademamix algorithm.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great to me.
I think I would say "fast EMA of first moment" rather than first moment personally (same for "slow EMA" and this would apply in the other docstrings). But if you have stronger preferences I won't argue against.
Okay thanks for the tip to render the docs. That allowed me to fix a lot of weirdness (and a LaTeX error!). I think this is good to go! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! I don't think I have approval permissions but just want to say nice work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you again @mathDR !
So @vroulet I think you have to merge it? Or does another authorized user have to do it? |
The PR needs to be approved by one of the internal owners of the package (as I did), then copybara automatically syncs the PR with the internal code, produces a snapshot for another maintainer to check and once that other maintainer gives his/her approval the PR is merged. (That's why the PRs often take a bit of time to be merged even after they get approved). |
thanks for the contribution! Note that you'll need to edit |
Okay great. I updated Note, when I render the docs locally, my example renders a Is this an artifact of sphinx attempting to render the notebook? Hopefully this is a problem with my local env and will not persist in |
that's because your test is taking too long to run. you can either make it quicker, or add it to nb_execution_excludepatterns in https://github.com/google-deepmind/optax/blob/main/docs/conf.py (but then it won't be run as part of the test suite) |
This PR adds the AdaMAMix optimizer from The arxiv preprint: THE ADEMAMIX OPTIMIZER:
BETTER, FASTER, OLDER
Closes #1058
The docs have been updated, along with the relevant files.
Currently the docstrings are implemented, but further descriptions should/could be added. I will reach out to the paper authors to assist with that (if they are willing).